import torch

# Demo: operations executed under torch.no_grad() are not recorded by
# autograd, so they contribute nothing to the gradient of the leaf tensor.

# 1x1 leaf tensor of ones that autograd tracks.
x = torch.ones((1, 1), requires_grad=True)
print(x)

# Tracked term x^2: recorded in the graph, contributes 2*x to x.grad.
y1 = torch.mul(x, x)

# Untracked term x^3: computed with gradient recording disabled, so the
# result carries no graph history and acts as a constant in backward().
with torch.no_grad():
    y2 = torch.mul(torch.mul(x, x), x)

# The sum still requires grad because y1 does; only the y1 path is
# differentiated.
y3 = torch.add(y1, y2)

# Reduce to a scalar so backward() can be called without an explicit
# gradient argument.
out = torch.mean(y3)
out.backward()

# Only the x^2 term flowed back, so this prints 2*x rather than 2*x + 3*x^2.
print(x.grad)
